WIP: Add top_k compatibility
#158
Draft
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
This references the PR data-apis/array-api-tests#274 and implements the compatibility layer for
top_k.Summary of Compatibility
jax:top_kdoes not implementaxisorlargestarguments. Whileaxisis easily implemented withjax.numpy.swapaxes,largestis not. Implementing the spec in JAX can be done similar to the pure python implementation in WIP: top_k draft implementation numpy/numpy#26666.jax.numpy.partitionandjax.numpy.argpartitionjax-ml/jax#22137.numpy:dask:top_kis currently about 2x longer than it has to be since computing the indices and values has to be done separately. This can be rectified whentake_along_axisis implemented in dask: Add NumPy's new take_along_axis dask/dask#3663.torch:Process
As mentioned in the referenced PR, since the process I went through is likely going to be repeated again, here are the steps I took:
array-apithat adds the corresponding specification..draft.array-api-testswhich implements the new tests and has itsarray-apisubmodule pointing to the newly createdarray-apibranch.array-api-compat(This PR) that implements the compatibility and points the CI to the newly createdarray-api-testsbranch.ARRAY_API_TESTS_VERSION=draftin the CI.Since I was implementing tests and compatibility on a non-existent spec, developing all 3 concurrently was incredibly messy. As of now I don't have much opinions on how to improve this process, but a documentation page of the necessary steps will be really helpful for future contributors.